{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "import time\n",
    "import gzip\n",
    "import urllib\n",
    "import collections\n",
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base URL (with trailing slash) that the MNIST archive names are appended to.\n",
    "# Points at the Google-hosted HTTPS mirror; the original plain-HTTP host\n",
    "# (http://yann.lecun.com/exdb/mnist/) has been unreliable for scripted downloads.\n",
    "DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _read32(bytestream):\n",
    "    \"\"\"Read four bytes from `bytestream` and decode them as one big-endian uint32.\"\"\"\n",
    "    raw = bytestream.read(4)\n",
    "    big_endian_u4 = np.dtype('>u4')\n",
    "    decoded = np.frombuffer(raw, dtype=big_endian_u4)\n",
    "    return decoded[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_images(f):\n",
    "    \"\"\"Extract the images into a 4D uint8 numpy array [index, y, x, depth].\n",
    "\n",
    "    Args:\n",
    "        f: A binary file object with a `name` attribute that yields a\n",
    "            gzip-compressed MNIST image file.\n",
    "\n",
    "    Returns:\n",
    "        A uint8 numpy array of shape [num_images, rows, cols, 1].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the magic number is not 2051, or the stream holds fewer\n",
    "            pixel bytes than the header announces.\n",
    "    \"\"\"\n",
    "    print('Extracting', f.name)\n",
    "    with gzip.GzipFile(fileobj=f) as bytestream:\n",
    "        magic = _read32(bytestream)\n",
    "        if magic != 2051:\n",
    "            raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name))\n",
    "        # Convert the header fields to Python ints so the size product below\n",
    "        # cannot wrap around in fixed-width 32-bit arithmetic.\n",
    "        num_images = int(_read32(bytestream))\n",
    "        rows = int(_read32(bytestream))\n",
    "        cols = int(_read32(bytestream))\n",
    "        expected_bytes = rows * cols * num_images\n",
    "        buf = bytestream.read(expected_bytes)\n",
    "        # Report a short payload explicitly instead of through a reshape error.\n",
    "        if len(buf) != expected_bytes:\n",
    "            raise ValueError('Expected %d pixel bytes in MNIST image file %s, got %d' % (expected_bytes, f.name, len(buf)))\n",
    "        data = np.frombuffer(buf, dtype=np.uint8)\n",
    "        return data.reshape(num_images, rows, cols, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dense_to_one_hot(labels_dense, num_classes):\n",
    "    \"\"\"Convert class labels from scalars to one-hot vectors.\n",
    "\n",
    "    Args:\n",
    "        labels_dense: Array-like of integer class ids, one per example.\n",
    "        num_classes: Width of each one-hot row.\n",
    "\n",
    "    Returns:\n",
    "        A float64 numpy array of shape [num_labels, num_classes] holding a\n",
    "        single 1 per row, in the column given by that example's class id.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If any class id lies outside [0, num_classes).\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels_dense)\n",
    "    num_labels = labels.shape[0]\n",
    "    class_ids = labels.ravel()\n",
    "    # A flat-index write would silently land in another example's row for an\n",
    "    # out-of-range id, so reject those up front.\n",
    "    if class_ids.size and (class_ids.min() < 0 or class_ids.max() >= num_classes):\n",
    "        raise ValueError('Labels must lie in [0, %d); got values from %d to %d' % (num_classes, class_ids.min(), class_ids.max()))\n",
    "    labels_one_hot = np.zeros((num_labels, num_classes))\n",
    "    labels_one_hot[np.arange(num_labels), class_ids] = 1\n",
    "    return labels_one_hot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_labels(f, one_hot=False, num_classes=10):\n",
    "    \"\"\"Extract the labels from a gzip-compressed MNIST label file.\n",
    "\n",
    "    Args:\n",
    "        f: A binary file object with a `name` attribute that yields a\n",
    "            gzip-compressed MNIST label file.\n",
    "        one_hot: If true, return the result of `dense_to_one_hot` instead of\n",
    "            the raw class ids.\n",
    "        num_classes: Number of classes, forwarded to `dense_to_one_hot`.\n",
    "\n",
    "    Returns:\n",
    "        A 1D uint8 numpy array [index], or, when `one_hot` is set, the 2D\n",
    "        [index, num_classes] array built by `dense_to_one_hot`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the magic number is not 2049, or the stream holds fewer\n",
    "            labels than the header announces.\n",
    "    \"\"\"\n",
    "    print('Extracting', f.name)\n",
    "    with gzip.GzipFile(fileobj=f) as bytestream:\n",
    "        magic = _read32(bytestream)\n",
    "        if magic != 2049:\n",
    "            raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, f.name))\n",
    "        num_items = int(_read32(bytestream))\n",
    "        buf = bytestream.read(num_items)\n",
    "        # Without this check a short payload would be returned silently.\n",
    "        if len(buf) != num_items:\n",
    "            raise ValueError('Expected %d labels in MNIST label file %s, got %d' % (num_items, f.name, len(buf)))\n",
    "        labels = np.frombuffer(buf, dtype=np.uint8)\n",
    "        if one_hot:\n",
    "            return dense_to_one_hot(labels, num_classes)\n",
    "        return labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataSet(object):\n",
    "    \"\"\"Container class for a dataset.\n",
    "\n",
    "    Holds parallel `images` and `labels` collections and hands them out in\n",
    "    mini-batches through `next_batch`, reshuffling when an epoch is exhausted.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, images, labels, fake_data=False, one_hot=False, dtype=tf.float32, reshape=True, seed=None):\n",
    "        \"\"\"Construct a DataSet.\n",
    "\n",
    "        Args:\n",
    "            images: Array indexed [example, row, column, depth]; stored as-is,\n",
    "                without preprocessing, when `fake_data` is set.\n",
    "            labels: Array with one entry per image; stored as-is.\n",
    "            fake_data: If true, skip all preprocessing and report 10000 examples.\n",
    "            one_hot: Selects the label format of the fake batches built by `next_batch`.\n",
    "            dtype: `tf.uint8` keeps the pixel values untouched, `tf.float32`\n",
    "                rescales them from [0, 255] to [0.0, 1.0].\n",
    "            reshape: If true, flatten every image to rows * columns values\n",
    "                (requires depth == 1).\n",
    "            seed: Passed to `tf.get_seed`; the result seeds numpy's global RNG.\n",
    "\n",
    "        Raises:\n",
    "            TypeError: If `dtype` is neither uint8 nor float32.\n",
    "        \"\"\"\n",
    "        seed1, seed2 = tf.get_seed(seed)\n",
    "        # If op level seed is not set, use whatever graph level seed is returned\n",
    "        np.random.seed(seed1 if seed is None else seed2)\n",
    "        dtype = tf.as_dtype(dtype).base_dtype\n",
    "        if dtype not in (tf.uint8, tf.float32):\n",
    "            raise TypeError('Invalid image dtype %r, expected uint8 or float32' % dtype)\n",
    "        # Always record the label format: next_batch(fake_data=True) reads it,\n",
    "        # including on a DataSet that was built from real data.\n",
    "        self.one_hot = one_hot\n",
    "        if fake_data:\n",
    "            self._num_examples = 10000\n",
    "        else:\n",
    "            assert images.shape[0] == labels.shape[0], ('images.shape: %s labels.shape: %s' % (images.shape, labels.shape))\n",
    "            self._num_examples = images.shape[0]\n",
    "\n",
    "            # Convert shape from [num examples, rows, columns, depth]\n",
    "            # to [num examples, rows*columns] (assuming depth == 1)\n",
    "            if reshape:\n",
    "                assert images.shape[3] == 1\n",
    "                images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])\n",
    "            if dtype == tf.float32:\n",
    "                # Convert from [0, 255] -> [0.0, 1.0].\n",
    "                images = images.astype(np.float32)\n",
    "                images = np.multiply(images, 1.0 / 255.0)\n",
    "        self._images = images\n",
    "        self._labels = labels\n",
    "        self._epochs_completed = 0\n",
    "        self._index_in_epoch = 0\n",
    "\n",
    "    @property\n",
    "    def images(self):\n",
    "        \"\"\"The stored images, in their current (possibly shuffled) order.\"\"\"\n",
    "        return self._images\n",
    "\n",
    "    @property\n",
    "    def labels(self):\n",
    "        \"\"\"The stored labels, kept in the same order as `images`.\"\"\"\n",
    "        return self._labels\n",
    "\n",
    "    @property\n",
    "    def num_examples(self):\n",
    "        \"\"\"Number of examples in one epoch.\"\"\"\n",
    "        return self._num_examples\n",
    "\n",
    "    @property\n",
    "    def epochs_completed(self):\n",
    "        \"\"\"How many times `next_batch` has run past the end of the data.\"\"\"\n",
    "        return self._epochs_completed\n",
    "\n",
    "    def next_batch(self, batch_size, fake_data=False, shuffle=True):\n",
    "        \"\"\"Return the next `batch_size` examples from this data set.\n",
    "\n",
    "        Args:\n",
    "            batch_size: Number of examples requested.\n",
    "            fake_data: If true, return constant placeholder examples and leave\n",
    "                the epoch bookkeeping untouched.\n",
    "            shuffle: If true, permute the data before the first batch and each\n",
    "                time an epoch is finished.\n",
    "\n",
    "        Returns:\n",
    "            An `(images, labels)` pair. A batch that crosses an epoch boundary\n",
    "            joins the tail of the old epoch with the head of the next one.\n",
    "        \"\"\"\n",
    "        if fake_data:\n",
    "            fake_image = [1] * 784\n",
    "            if self.one_hot:\n",
    "                fake_label = [1] + [0] * 9\n",
    "            else:\n",
    "                fake_label = 0\n",
    "            return [fake_image for _ in range(batch_size)], [fake_label for _ in range(batch_size)]\n",
    "        start = self._index_in_epoch\n",
    "        # Shuffle for the first epoch\n",
    "        if self._epochs_completed == 0 and start == 0 and shuffle:\n",
    "            perm0 = np.arange(self._num_examples)\n",
    "            np.random.shuffle(perm0)\n",
    "            self._images = self.images[perm0]\n",
    "            self._labels = self.labels[perm0]\n",
    "        # Go to the next epoch\n",
    "        if start + batch_size > self._num_examples:\n",
    "            # Finished epoch\n",
    "            self._epochs_completed += 1\n",
    "            # Get the rest examples in this epoch\n",
    "            rest_num_examples = self._num_examples - start\n",
    "            images_rest_part = self._images[start:self._num_examples]\n",
    "            labels_rest_part = self._labels[start:self._num_examples]\n",
    "            # Shuffle the data\n",
    "            if shuffle:\n",
    "                perm = np.arange(self._num_examples)\n",
    "                np.random.shuffle(perm)\n",
    "                self._images = self.images[perm]\n",
    "                self._labels = self.labels[perm]\n",
    "            # Start next epoch: top the batch up from the front of the new order\n",
    "            start = 0\n",
    "            self._index_in_epoch = batch_size - rest_num_examples\n",
    "            end = self._index_in_epoch\n",
    "            images_new_part = self._images[start:end]\n",
    "            labels_new_part = self._labels[start:end]\n",
    "            return np.concatenate((images_rest_part, images_new_part), axis=0), np.concatenate((labels_rest_part, labels_new_part), axis=0)\n",
    "        else:\n",
    "            self._index_in_epoch += batch_size\n",
    "            end = self._index_in_epoch\n",
    "            return self._images[start:end], self._labels[start:end]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def urlretrieve_with_retry(url, filename=None, max_attempts=3, initial_delay=1.0):\n",
    "    \"\"\"Download `url` with `urllib.request.urlretrieve`, retrying failed attempts.\n",
    "\n",
    "    Args:\n",
    "        url: Address of the resource to fetch.\n",
    "        filename: Optional destination path, passed through to `urlretrieve`.\n",
    "        max_attempts: Total number of tries before giving up; at least one\n",
    "            attempt is always made.\n",
    "        initial_delay: Seconds to sleep before the second attempt; the delay\n",
    "            doubles after every further failure.\n",
    "\n",
    "    Returns:\n",
    "        The `(local_filename, headers)` tuple returned by `urlretrieve`.\n",
    "\n",
    "    Raises:\n",
    "        OSError: The last error seen once the attempts are used up. HTTP 4xx\n",
    "            responses are re-raised immediately without retrying.\n",
    "    \"\"\"\n",
    "    # A bare `import urllib` does not load the `request` submodule, so import\n",
    "    # the submodules explicitly here.\n",
    "    import urllib.error\n",
    "    import urllib.request\n",
    "\n",
    "    delay = initial_delay\n",
    "    attempt = 1\n",
    "    while True:\n",
    "        try:\n",
    "            return urllib.request.urlretrieve(url, filename)\n",
    "        except urllib.error.HTTPError as err:\n",
    "            # Client errors such as 404 will not go away on a retry.\n",
    "            if 400 <= err.code < 500 or attempt >= max_attempts:\n",
    "                raise\n",
    "        except OSError:\n",
    "            if attempt >= max_attempts:\n",
    "                raise\n",
    "        time.sleep(delay)\n",
    "        delay *= 2\n",
    "        attempt += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def maybe_download(filename, work_directory, source_url):\n",
    "    \"\"\"Return the local path of `filename`, fetching it from `source_url` if absent.\n",
    "\n",
    "    Creates `work_directory` when it does not exist. A file that is already\n",
    "    present is reused without contacting the network.\n",
    "    \"\"\"\n",
    "    if not tf.gfile.Exists(work_directory):\n",
    "        tf.gfile.MakeDirs(work_directory)\n",
    "    filepath = os.path.join(work_directory, filename)\n",
    "    if tf.gfile.Exists(filepath):\n",
    "        # Already cached locally.\n",
    "        return filepath\n",
    "    downloaded_path, _headers = urlretrieve_with_retry(source_url)\n",
    "    tf.gfile.Copy(downloaded_path, filepath)\n",
    "    with tf.gfile.GFile(filepath) as stored:\n",
    "        size = stored.size()\n",
    "    print('Successfully downloaded', filename, size, 'bytes.')\n",
    "    return filepath"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle of the three data splits returned by `read_data_sets`.\n",
    "Datasets = collections.namedtuple('Datasets', 'train validation test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32, reshape=True, validation_size=5000, seed=None, source_url=DEFAULT_SOURCE_URL):\n",
    "    \"\"\"Load MNIST and split it into train, validation and test `DataSet`s.\n",
    "\n",
    "    The four archives are downloaded into `train_dir` when missing. The first\n",
    "    `validation_size` training examples become the validation split.\n",
    "\n",
    "    Args:\n",
    "        train_dir: Directory used to cache the downloaded archives.\n",
    "        fake_data: If true, return three fake `DataSet`s without any I/O.\n",
    "        one_hot: Forwarded to `extract_labels` (and to the fake `DataSet`s).\n",
    "        dtype: Forwarded to `DataSet`.\n",
    "        reshape: Forwarded to `DataSet`.\n",
    "        validation_size: Number of leading training examples to hold out.\n",
    "        seed: Forwarded to `DataSet`.\n",
    "        source_url: Base URL of the archives; an empty value selects\n",
    "            `DEFAULT_SOURCE_URL`.\n",
    "\n",
    "    Returns:\n",
    "        A `Datasets` tuple with `train`, `validation` and `test` members.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `validation_size` is negative or exceeds the number of\n",
    "            training images.\n",
    "    \"\"\"\n",
    "    if fake_data:\n",
    "        fake_options = dict(fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)\n",
    "        train = DataSet([], [], **fake_options)\n",
    "        validation = DataSet([], [], **fake_options)\n",
    "        test = DataSet([], [], **fake_options)\n",
    "        return Datasets(train=train, validation=validation, test=test)\n",
    "\n",
    "    # An empty string falls back to the default location.\n",
    "    source_url = source_url or DEFAULT_SOURCE_URL\n",
    "\n",
    "    def fetch(archive_name, extractor, **extractor_options):\n",
    "        \"\"\"Make sure one archive is available locally, then parse it.\"\"\"\n",
    "        local_file = maybe_download(archive_name, train_dir, source_url + archive_name)\n",
    "        with tf.gfile.Open(local_file, 'rb') as f:\n",
    "            return extractor(f, **extractor_options)\n",
    "\n",
    "    train_images = fetch('train-images-idx3-ubyte.gz', extract_images)\n",
    "    train_labels = fetch('train-labels-idx1-ubyte.gz', extract_labels, one_hot=one_hot)\n",
    "    test_images = fetch('t10k-images-idx3-ubyte.gz', extract_images)\n",
    "    test_labels = fetch('t10k-labels-idx1-ubyte.gz', extract_labels, one_hot=one_hot)\n",
    "\n",
    "    num_train_images = len(train_images)\n",
    "    if not 0 <= validation_size <= num_train_images:\n",
    "        raise ValueError('Validation size should be between 0 and {}. Received: {}.'.format(num_train_images, validation_size))\n",
    "\n",
    "    # Hold out the head of the training data for validation.\n",
    "    validation_images, train_images = train_images[:validation_size], train_images[validation_size:]\n",
    "    validation_labels, train_labels = train_labels[:validation_size], train_labels[validation_size:]\n",
    "\n",
    "    options = dict(dtype=dtype, reshape=reshape, seed=seed)\n",
    "    return Datasets(\n",
    "        train=DataSet(train_images, train_labels, **options),\n",
    "        validation=DataSet(validation_images, validation_labels, **options),\n",
    "        test=DataSet(test_images, test_labels, **options),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where the MNIST archives are cached and where checkpoints are written.\n",
    "input_data_dir = 'input_data'\n",
    "log_dir = 'logs/mnist_deep'\n",
    "\n",
    "# Optimisation settings used by the training cells below.\n",
    "learning_rate = 0.01\n",
    "batch_size = 100\n",
    "max_steps = 1800"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from an empty log directory: remove whatever an earlier run left behind.\n",
    "stale_logs_present = tf.gfile.Exists(log_dir)\n",
    "if stale_logs_present:\n",
    "    tf.gfile.DeleteRecursively(log_dir)\n",
    "tf.gfile.MakeDirs(log_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting input_data/train-images-idx3-ubyte.gz\n",
      "Extracting input_data/train-labels-idx1-ubyte.gz\n",
      "Extracting input_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting input_data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Load the train/validation/test splits, caching the archives in `input_data_dir`.\n",
    "data_sets = read_data_sets(train_dir=input_data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST images are square, single-channel 28x28 bitmaps.\n",
    "IMAGE_SIZE = 28\n",
    "NUM_CHANNELS = 1\n",
    "IMAGE_PIXELS = IMAGE_SIZE ** 2\n",
    "\n",
    "# One class per digit, 0 through 9.\n",
    "NUM_CLASSES = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flattened input images: one row of IMAGE_PIXELS floats per example.\n",
    "x = tf.placeholder(tf.float32, shape=[None, IMAGE_PIXELS])\n",
    "# One integer class id per example (sparse labels, not one-hot).\n",
    "y_ = tf.placeholder(tf.int64, shape=[None])\n",
    "# Restore the [batch, height, width, channels] layout for the convolution layers.\n",
    "image_shape = [-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS]\n",
    "x_image = tf.reshape(x, image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv2d(x, W):\n",
    "    \"\"\"Convolve `x` with the filter bank `W` using stride 1 and 'SAME' padding.\"\"\"\n",
    "    unit_strides = [1, 1, 1, 1]\n",
    "    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')\n",
    "\n",
    "def max_pool_2x2(x):\n",
    "    \"\"\"Max-pool `x` over 2x2 windows with stride 2 and 'SAME' padding.\"\"\"\n",
    "    window = [1, 2, 2, 1]\n",
    "    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')\n",
    "\n",
    "def weight_variable(shape):\n",
    "    \"\"\"Create a variable of `shape` initialised from a truncated normal (stddev 0.1).\"\"\"\n",
    "    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))\n",
    "\n",
    "def bias_variable(shape):\n",
    "    \"\"\"Create a variable of `shape` with every element initialised to 0.1.\"\"\"\n",
    "    return tf.Variable(tf.constant(0.1, shape=shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model graph: two conv + max-pool stages, a dense layer with dropout, and a\n",
    "# linear read-out producing one logit per class.\n",
    "\n",
    "# Stage 1: 5x5 filters over the 1-channel input producing 32 feature maps,\n",
    "# followed by ReLU and a 2x2 max-pool.\n",
    "W_conv1 = weight_variable([5, 5, 1, 32])\n",
    "b_conv1 = bias_variable([32])\n",
    "h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)\n",
    "h_pool1 = max_pool_2x2(h_conv1)\n",
    "\n",
    "# Stage 2: 5x5 filters mapping the 32 feature maps to 64, again ReLU + pool.\n",
    "W_conv2 = weight_variable([5, 5, 32, 64])\n",
    "b_conv2 = bias_variable([64])\n",
    "h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)\n",
    "h_pool2 = max_pool_2x2(h_conv2)\n",
    "\n",
    "# Dense layer: the pooled output is flattened to 7 * 7 * 64 values per example\n",
    "# (two stride-2 pools take 28 -> 14 -> 7) and mapped to 512 units.\n",
    "W_fc1 = weight_variable([7 * 7 * 64, 512])\n",
    "b_fc1 = bias_variable([512])\n",
    "h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])\n",
    "h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)\n",
    "\n",
    "# Dropout on the dense activations; the training loop feeds keep_prob as 0.5\n",
    "# for training steps and 1.0 for evaluation.\n",
    "keep_prob = tf.placeholder(tf.float32)\n",
    "h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)\n",
    "\n",
    "# Read-out layer: unnormalised class scores (logits), one per class.\n",
    "W_fc2 = weight_variable([512, NUM_CLASSES])\n",
    "b_fc2 = bias_variable([NUM_CLASSES])\n",
    "y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Softmax cross-entropy between the integer labels and the logits, averaged\n",
    "# into the scalar that the optimiser minimises.\n",
    "batch_loss = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y_conv)\n",
    "cross_entropy = tf.reduce_mean(batch_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One optimisation step: Adam applied to the cross-entropy loss.\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
    "train_step = optimizer.minimize(cross_entropy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the highest-scoring class of every example with its label.\n",
    "predicted_class = tf.argmax(y_conv, 1)\n",
    "is_correct = tf.equal(predicted_class, y_)\n",
    "correct_prediction = tf.cast(is_correct, tf.float32)\n",
    "# Fraction of correct predictions in the fed batch.\n",
    "accuracy = tf.reduce_mean(correct_prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Variable-initialisation op; the training cell runs it once right after\n",
    "# opening the session.\n",
    "init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saver used by the training cell to write checkpoints under `log_dir`.\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 100, training accuracy 0.94\n",
      "step 200, training accuracy 0.97\n",
      "step 300, training accuracy 0.96\n",
      "step 400, training accuracy 0.95\n",
      "step 500, training accuracy 0.95\n",
      "step 600, training accuracy 0.97\n",
      "test accuracy 0.9663\n",
      "step 700, training accuracy 0.96\n",
      "step 800, training accuracy 0.98\n",
      "step 900, training accuracy 0.97\n",
      "step 1000, training accuracy 0.98\n",
      "step 1100, training accuracy 0.95\n",
      "step 1200, training accuracy 0.99\n",
      "test accuracy 0.969\n",
      "step 1300, training accuracy 0.97\n",
      "step 1400, training accuracy 0.97\n",
      "step 1500, training accuracy 0.98\n",
      "step 1600, training accuracy 0.96\n",
      "step 1700, training accuracy 0.97\n",
      "step 1800, training accuracy 0.98\n",
      "test accuracy 0.9667\n"
     ]
    }
   ],
   "source": [
    "# Train for `max_steps` mini-batches. Every 100 steps report the accuracy on\n",
    "# the batch just trained on; every 600 steps (and at the very end) write a\n",
    "# checkpoint and evaluate on the test split.\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "\n",
    "    for step_number in range(1, max_steps + 1):\n",
    "        train_images, train_labels = data_sets.train.next_batch(batch_size)\n",
    "        train_step.run(feed_dict={x: train_images, y_: train_labels, keep_prob: 0.5})\n",
    "\n",
    "        if step_number % 100 == 0:\n",
    "            # Dropout is switched off (keep_prob 1.0) whenever accuracy is measured.\n",
    "            train_accuracy = accuracy.eval(feed_dict={x: train_images, y_: train_labels, keep_prob: 1.0})\n",
    "            print('step %d, training accuracy %g' % (step_number, train_accuracy))\n",
    "\n",
    "        if step_number % 600 == 0 or step_number == max_steps:\n",
    "            checkpoint_file = os.path.join(log_dir, 'model.ckpt')\n",
    "            saver.save(sess, checkpoint_file)\n",
    "            # Average the accuracy over 20 unshuffled batches of 500 test examples.\n",
    "            accuracy_l = []\n",
    "            for _ in range(20):\n",
    "                test_images, test_labels = data_sets.test.next_batch(500, shuffle=False)\n",
    "                accuracy_l.append(accuracy.eval(feed_dict={x: test_images, y_: test_labels, keep_prob: 1.0}))\n",
    "            print('test accuracy %g' % np.mean(accuracy_l))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
